{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "e54974c3",
      "metadata": {
        "id": "e54974c3"
      },
      "source": [
        "# Few-shot Relation Extraction Tutorial\n",
        "\n",
        "> Tutorial作者: 黎洲波（zhoubo.li@zju.edu.cn）\n",
        "\n",
        "In this tutorial, we use [KnowPrompt](https://arxiv.org/abs/2104.07650v5) to extract relational triples after being trained on few-shot datasets. We hope this tutorial can help you understand the process of few-shot relation extraction.\n",
        "\n",
        "This tutorial uses `Python3`."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "dc428237",
      "metadata": {
        "id": "dc428237"
      },
      "source": [
        "## RE\n",
        "**Relation extraction** (RE), a key task in information extraction, predicts semantic relations between pairs of entities from unstructured\n",
        "texts.\n",
        "\n",
        "## Few-shot RE\n",
        "Few-shot relation extraction in DeepKE is based on the *pre-train, prompt, and predict* paradigm, feeds Prompt parameters into attention and cross-attention in BERT and fine-tunes them on few-shot datasets, which obtains excellent performance on the low-resource scenario. The prompt-tuning method is shown in the following picture:\n",
        "![关系抽取中的Prompt-tuning](img/img1.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "7f79ef50",
      "metadata": {
        "id": "7f79ef50"
      },
      "source": [
        "## Dataset\n",
        "There are some few-shot RE datasets including RETACRED, SEMEVAL, TACREV, WIKI80, .etc. The tutorial uses [SEMEVAL](https://semeval2.fbk.eu/semeval2.php?location=tasks#T11), which is from [SemEval-2010 Task 8: Multi-Way Classification of Semantic Relations between Pairs of Nominals](https://arxiv.org/abs/1911.10422). The structure of the dataset folder `./data/` is as follow:\n",
        "\n",
        "```\n",
        ".\n",
        "├── rel2id.json                     # Relation Label - ID Map\n",
        "├── temp.txt                        # Relation Label\n",
        "├── test.txt                        # Test Set\n",
        "├── train.txt                       # Training Set\n",
        "└── val.txt                         # Validation Set\n",
        "```\n",
        "\n",
        "The data formats of SEMEVAL are described as follow:\n",
        "\n",
        "```\n",
        "Data Format:\n",
        "{\n",
        "    'token': [tokens in a sentence],\n",
        "    \"h\": {\n",
        "        \"name\": mention_name,\n",
        "        \"pos\" : [postion of mention in a sentence]\n",
        "    },\n",
        "    \"t\": {\n",
        "        \"name\": mention_name,\n",
        "        \"pos\" : [postion of mention in a sentence]\n",
        "    },\n",
        "    \"relation\": relation\n",
        "}\n",
        "```\n",
        "There are 9+1 relation types in SEMEVAL and their proportions are shown in the following table:\n",
        "\n",
        "![数据集数据占比](img/img2.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f7307aab",
      "metadata": {
        "id": "f7307aab"
      },
      "source": [
        "## KnowPrompt\n",
        "In DeepKE, we use the Prompt method that can parse relational labels semantically, which is called Knowledge-aware Prompt-tuning (KnowPrompt). The frameworks of Fine-tuning (Fig. a), Prompt-tuning (Fig. b) and KnowPrompt (Fig. c) we use are in the following picture. The answer words in Prompt refer to virtual answer words.\n",
        "\n",
        "![低资源关系抽取架构图](img/img3.png)"
      ]
    },
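    {
      "cell_type": "markdown",
      "id": "3b7e1c9a",
      "metadata": {
        "id": "3b7e1c9a"
      },
      "source": [
        "To make the KnowPrompt input more concrete, here is a minimal sketch of the prompt layout used by the preprocessing code in this tutorial: the subject and object mentions are wrapped in `[sub]` / `[obj]` markers around a mask slot, and every relation gets one virtual answer word `[class1]`, `[class2]`, and so on. The helper names `build_prompt` and `virtual_answer_words`, the sample label map, and the assumption that relation id `i` pairs with `[class{i+1}]` are illustrative only; the real code uses the tokenizer's own mask token.\n",
        "\n",
        "```python\n",
        "# Hypothetical helpers; the template mirrors the prompt string built during preprocessing.\n",
        "def build_prompt(subject, obj, mask_token='[MASK]'):\n",
        "    # The model is asked to predict a virtual answer word at the mask position.\n",
        "    return f'[sub] {subject} [sub] {mask_token} [obj] {obj} [obj] .'\n",
        "\n",
        "def virtual_answer_words(rel2id):\n",
        "    # Assumption: relation id i is paired with the token [class{i+1}].\n",
        "    return {rel: f'[class{idx + 1}]' for rel, idx in rel2id.items()}\n",
        "\n",
        "# Made-up label map for illustration only.\n",
        "rel2id = {'Other': 0, 'Component-Whole(e1,e2)': 1}\n",
        "prompt = build_prompt('engine', 'car')\n",
        "answers = virtual_answer_words(rel2id)\n",
        "print(prompt)\n",
        "print(answers)\n",
        "```"
      ]
    },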
    {
      "cell_type": "markdown",
      "id": "52d41b71",
      "metadata": {
        "id": "52d41b71"
      },
      "source": [
        "## Prepare the runtime environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b6ff3e4a",
      "metadata": {
        "id": "b6ff3e4a"
      },
      "outputs": [],
      "source": [
        "!pip install deepke\n",
        "!wget 120.27.214.45/Data/re/few_shot/data.tar.gz\n",
        "!tar -xzvf data.tar.gz"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "55ec76b1",
      "metadata": {
        "id": "55ec76b1"
      },
      "source": [
        "## Import modules"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2812166f",
      "metadata": {
        "id": "2812166f"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import json\n",
        "import csv\n",
        "import time\n",
        "import pickle\n",
        "import logging\n",
        "import shutil\n",
        "import numpy as np\n",
        "from tqdm import tqdm\n",
        "from functools import partial\n",
        "from collections import Counter\n",
        "from collections import OrderedDict\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.nn.utils.rnn import pad_sequence\n",
        "from torch.utils.data import DataLoader, TensorDataset\n",
        "from transformers import AutoConfig, AutoModel, AutoTokenizer\n",
        "from transformers.modeling_utils import PreTrainedModel\n",
        "from transformers.optimization import AdamW, get_linear_schedule_with_warmup\n",
        "from transformers import BertTokenizer, AutoModelForMaskedLM\n",
        "\n",
        "logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s', \n",
        "                    datefmt = '%m/%d/%Y %H:%M:%S',\n",
        "                    level = logging.INFO)\n",
        "logger = logging.getLogger(__name__)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "6c687c19",
      "metadata": {
        "id": "6c687c19"
      },
      "source": [
        "## Config parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2befe17f",
      "metadata": {
        "id": "2befe17f"
      },
      "outputs": [],
      "source": [
        "class Config(object):\n",
        "    accelerator = None\n",
        "    accumulate_grad_batches = 1\n",
        "    amp_backend = 'native'\n",
        "    amp_level = 'O2'\n",
        "    auto_lr_find = False\n",
        "    auto_scale_batch_size = False\n",
        "    auto_select_gpus = False\n",
        "    batch_size = 16\n",
        "    benchmark = False\n",
        "    check_val_every_n_epoch = '3'\n",
        "    checkpoint_callback = True\n",
        "    data_class = 'REDataset'\n",
        "    data_dir = 'data/'\n",
        "    default_root_dir = None\n",
        "    deterministic = False\n",
        "    devices = None\n",
        "    distributed_backend = None\n",
        "    fast_dev_run = False\n",
        "    flush_logs_every_n_steps = 100\n",
        "    gpus = None\n",
        "    gradient_accumulation_steps = 1\n",
        "    gradient_clip_algorithm = 'norm'\n",
        "    gradient_clip_val = 0.0\n",
        "    ipus = None\n",
        "    limit_predict_batches = 1.0\n",
        "    limit_test_batches = 1.0\n",
        "    limit_train_batches = 1.0\n",
        "    limit_val_batches = 1.0\n",
        "    litmodel_class = 'BertLitModel'\n",
        "    load_checkpoint = None\n",
        "    log_dir = './model_bert.log'\n",
        "    log_every_n_steps = 50\n",
        "    log_gpu_memory = None\n",
        "    logger = True\n",
        "    lr = 3e-05\n",
        "    lr_2 = 3e-05\n",
        "    max_epochs = '30'\n",
        "    max_seq_length = 256\n",
        "    max_steps = None\n",
        "    max_time = None\n",
        "    min_epochs = None\n",
        "    min_steps = None\n",
        "    model_class = 'BertForMaskedLM'\n",
        "    model_name_or_path = 'bert-base-uncased'\n",
        "    move_metrics_to_cpu = False\n",
        "    multiple_trainloader_mode = 'max_size_cycle'\n",
        "    num_nodes = 1\n",
        "    num_processes = 1\n",
        "    num_sanity_val_steps = 2\n",
        "    num_train_epochs = 30\n",
        "    num_workers = 8\n",
        "    optimizer = 'AdamW'\n",
        "    overfit_batches = 0.0\n",
        "    plugins = None\n",
        "    precision = 32\n",
        "    prepare_data_per_node = True\n",
        "    process_position = 0\n",
        "    profiler = None\n",
        "    progress_bar_refresh_rate = None\n",
        "    ptune_k = 7\n",
        "    reload_dataloaders_every_epoch = False\n",
        "    reload_dataloaders_every_n_epochs = 0\n",
        "    replace_sampler_ddp = True\n",
        "    resume_from_checkpoint = None\n",
        "    save_path = './model_bert.pt'\n",
        "    seed = 666\n",
        "    stochastic_weight_avg = False\n",
        "    sync_batchnorm = False\n",
        "    t_lambda = 0.001\n",
        "    task_name = 'wiki80'\n",
        "    terminate_on_nan = False\n",
        "    tpu_cores = None\n",
        "    track_grad_norm = -1\n",
        "    train_from_saved_model = ''\n",
        "    truncated_bptt_steps = None\n",
        "    two_steps = False\n",
        "    use_prompt = True\n",
        "    val_check_interval = 1.0\n",
        "    wandb = False\n",
        "    weight_decay = 0.01\n",
        "    weights_save_path = None\n",
        "    weights_summary = 'top'\n",
        "    load_path = './model_bert.pt'\n",
        "    \n",
        "cfg = Config()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "11897fb8",
      "metadata": {
        "id": "11897fb8"
      },
      "source": [
        "## Preprocess Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b209d40f",
      "metadata": {
        "id": "b209d40f"
      },
      "outputs": [],
      "source": [
        "class InputExampleWiki80(object):\n",
        "    \"\"\"A single training/test example for span pair classification.\"\"\"\n",
        "\n",
        "    def __init__(self, guid, sentence, span1, span2, ner1, ner2, label):\n",
        "        self.guid = guid\n",
        "        self.sentence = sentence\n",
        "        self.span1 = span1\n",
        "        self.span2 = span2\n",
        "        self.ner1 = ner1\n",
        "        self.ner2 = ner2\n",
        "        self.label = label\n",
        "\n",
        "class DataProcessor(object):\n",
        "    \"\"\"Base class for data converters for sequence classification data sets.\"\"\"\n",
        "\n",
        "    def get_train_examples(self, data_dir):\n",
        "        \"\"\"Gets a collection of `InputExample`s for the train set.\"\"\"\n",
        "        raise NotImplementedError()\n",
        "\n",
        "    def get_dev_examples(self, data_dir):\n",
        "        \"\"\"Gets a collection of `InputExample`s for the dev set.\"\"\"\n",
        "        raise NotImplementedError()\n",
        "\n",
        "    def get_labels(self):\n",
        "        \"\"\"Gets the list of labels for this data set.\"\"\"\n",
        "        raise NotImplementedError()\n",
        "\n",
        "    @classmethod\n",
        "    def _read_tsv(cls, input_file, quotechar=None):\n",
        "        \"\"\"Reads a tab separated value file.\"\"\"\n",
        "        with open(input_file, \"r\") as f:\n",
        "            reader = csv.reader(f, delimiter=\"\\t\", quotechar=quotechar)\n",
        "            lines = []\n",
        "            for line in reader:\n",
        "                lines.append(line)\n",
        "            return lines\n",
        "\n",
        "class wiki80Processor(DataProcessor):\n",
        "    \"\"\"Processor for the TACRED data set.\"\"\"\n",
        "    def __init__(self, tokenizer, data_path, use_prompt):\n",
        "        super().__init__()\n",
        "        self.data_dir = data_path\n",
        "\n",
        "    @classmethod\n",
        "    def _read_json(cls, input_file):\n",
        "        data = []\n",
        "        with open(input_file, \"r\", encoding='utf-8') as reader:\n",
        "            all_lines = reader.readlines()\n",
        "            for line in all_lines:\n",
        "                ins = eval(line)\n",
        "                data.append(ins)\n",
        "        return data\n",
        "\n",
        "    def get_train_examples(self, data_dir):\n",
        "        \"\"\"See base class.\"\"\"\n",
        "        return self._create_examples(\n",
        "            self._read_json(os.path.join(data_dir, \"train.txt\")), \"train\")\n",
        "\n",
        "    def get_dev_examples(self, data_dir):\n",
        "        \"\"\"See base class.\"\"\"\n",
        "        return self._create_examples(\n",
        "            self._read_json(os.path.join(data_dir, \"val.txt\")), \"dev\")\n",
        "\n",
        "    def get_test_examples(self, data_dir):\n",
        "        \"\"\"See base class.\"\"\"\n",
        "        return self._create_examples(\n",
        "            self._read_json(os.path.join(data_dir, \"test.txt\")), \"test\")\n",
        "\n",
        "    def get_labels(self, negative_label=\"no_relation\"):\n",
        "        data_dir = self.data_dir\n",
        "        \"\"\"See base class.\"\"\"\n",
        "        # if 'k-shot' in self.data_dir:\n",
        "        #     data_dir = os.path.abspath(os.path.join(self.data_dir, \"../..\"))\n",
        "        # else:\n",
        "        #     data_dir = self.data_dir\n",
        "        with open(os.path.join(data_dir,'rel2id.json'), \"r\", encoding='utf-8') as reader:\n",
        "            re2id = json.load(reader)\n",
        "        return re2id\n",
        "\n",
        "\n",
        "    def _create_examples(self, dataset, set_type):\n",
        "        \"\"\"Creates examples for the training and dev sets.\"\"\"\n",
        "        examples = []\n",
        "        for example in dataset:\n",
        "            sentence = example['token']\n",
        "            examples.append(InputExampleWiki80(guid=None,\n",
        "                            sentence=sentence,\n",
        "                            # maybe some bugs here, I don't -1\n",
        "                            span1=(example['h']['pos'][0], example['h']['pos'][1]),\n",
        "                            span2=(example['t']['pos'][0], example['t']['pos'][1]),\n",
        "                            ner1=None,\n",
        "                            ner2=None,\n",
        "                            label=example['relation']))\n",
        "        return examples\n",
        "\n",
        "\n",
        "    def _create_examples(self, dataset, set_type):\n",
        "        \"\"\"Creates examples for the training and dev sets.\"\"\"\n",
        "        examples = []\n",
        "        for example in dataset:\n",
        "            sentence = example['token']\n",
        "            examples.append(InputExampleWiki80(guid=None,\n",
        "                            sentence=sentence,\n",
        "                            # maybe some bugs here, I don't -1\n",
        "                            span1=(example['h']['pos'][0], example['h']['pos'][1]),\n",
        "                            span2=(example['t']['pos'][0], example['t']['pos'][1]),\n",
        "                            ner1=None,\n",
        "                            ner2=None,\n",
        "                            label=example['relation']))\n",
        "        return examples\n",
        "\n",
        "class BaseDataModule(nn.Module):\n",
        "    \"\"\"\n",
        "    Base DataModule.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, cfg) -> None:\n",
        "        super().__init__()\n",
        "        self.cfg = cfg if cfg is not None else {}\n",
        "        self.batch_size = self.cfg.batch_size\n",
        "        self.num_workers = self.cfg.num_workers\n",
        "\n",
        "    def get_data_config(self):\n",
        "        \"\"\"Return important settings of the dataset, which will be passed to instantiate models.\"\"\"\n",
        "        return { \"num_labels\": self.num_labels}\n",
        "\n",
        "    def prepare_data(self):\n",
        "        \"\"\"\n",
        "        Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).\n",
        "        \"\"\"\n",
        "        pass\n",
        "\n",
        "    def setup(self, stage=None):\n",
        "        \"\"\"\n",
        "        Split into train, val, test, and set dims.\n",
        "        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.\n",
        "        \"\"\"\n",
        "        self.data_train = None\n",
        "        self.data_val = None\n",
        "        self.data_test = None\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)\n",
        "\n",
        "    def test_dataloader(self):\n",
        "        return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)\n",
        "\n",
        "def convert_examples_to_features(examples, max_seq_length, tokenizer, cfg, rel2id):\n",
        "    \"\"\"Loads a data file into a list of `InputBatch`s.\"\"\"\n",
        "\n",
        "    save_file = \"data/cached_wiki80.pkl\"\n",
        "    mode = \"text\"\n",
        "\n",
        "    num_tokens = 0\n",
        "    num_fit_examples = 0\n",
        "    num_shown_examples = 0\n",
        "    instances = []\n",
        "    \n",
        "    use_bert = \"BertTokenizer\" in tokenizer.__class__.__name__\n",
        "    use_gpt = \"GPT\" in tokenizer.__class__.__name__\n",
        "    \n",
        "    assert not (use_bert and use_gpt), \"model cannot be gpt and bert together\"\n",
        "\n",
        "    if False:\n",
        "        with open(file=save_file, mode='rb') as fr:\n",
        "            instances = pickle.load(fr)\n",
        "        print('load preprocessed data from {}.'.format(save_file))\n",
        "\n",
        "    else:\n",
        "        print('loading..')\n",
        "        for (ex_index, example) in enumerate(examples):\n",
        "            \n",
        "\n",
        "            \"\"\"\n",
        "                the relation between SUBJECT and OBJECT is .\n",
        "                \n",
        "            \"\"\"\n",
        "\n",
        "            if ex_index % 10000 == 0:\n",
        "                logger.info(\"Writing example %d of %d\" % (ex_index, len(examples)))\n",
        "\n",
        "            tokens = []\n",
        "            SUBJECT_START = \"[subject_start]\"\n",
        "            SUBJECT_END = \"[subject_end]\"\n",
        "            OBJECT_START = \"[object_start]\"\n",
        "            OBJECT_END = \"[object_end]\"\n",
        "            \n",
        "\n",
        "            if mode.startswith(\"text\"):\n",
        "                for i, token in enumerate(example.sentence):\n",
        "                    if i == example.span1[0]:\n",
        "                        tokens.append(SUBJECT_START)\n",
        "                    if i == example.span2[0]:\n",
        "                        tokens.append(OBJECT_START)\n",
        "                    # for sub_token in tokenizer.tokenize(token):\n",
        "                    #     tokens.append(sub_token)\n",
        "                    if i == example.span1[1]:\n",
        "                        tokens.append(SUBJECT_END)\n",
        "                    if i == example.span2[1]:\n",
        "                        tokens.append(OBJECT_END)\n",
        "\n",
        "                    tokens.append(token)\n",
        "\n",
        "            SUBJECT = \" \".join(example.sentence[example.span1[0]: example.span1[1]])\n",
        "            OBJECT = \" \".join(example.sentence[example.span2[0]: example.span2[1]])\n",
        "            SUBJECT_ids = tokenizer(\" \"+SUBJECT, add_special_tokens=False)['input_ids']\n",
        "            OBJECT_ids = tokenizer(\" \"+OBJECT, add_special_tokens=False)['input_ids']\n",
        "            \n",
        "            if use_gpt:\n",
        "                if cfg.CT_CL:\n",
        "                    prompt = f\"[T1] [T2] [T3] [sub] {OBJECT} [sub] [T4] [obj] {SUBJECT} [obj] [T5] {tokenizer.cls_token}\"\n",
        "                else:\n",
        "                    prompt = f\"The relation between [sub] {SUBJECT} [sub] and [obj] {OBJECT} [obj] is {tokenizer.cls_token} .\"\n",
        "            else:\n",
        "                # add prompt [T_n] and entity marker [obj] to enrich the context.\n",
        "                prompt = f\"[sub] {SUBJECT} [sub] {tokenizer.mask_token} [obj] {OBJECT} [obj] .\"\n",
        "            \n",
        "            if ex_index == 0:\n",
        "                input_text = \" \".join(tokens)\n",
        "                logger.info(f\"input text : {input_text}\")\n",
        "                logger.info(f\"prompt : {prompt}\")\n",
        "                logger.info(f\"label : {example.label}\")\n",
        "            inputs = tokenizer(\n",
        "                prompt,\n",
        "                \" \".join(tokens),\n",
        "                truncation=\"longest_first\",\n",
        "                max_length=max_seq_length,\n",
        "                padding=\"max_length\",\n",
        "                add_special_tokens=True\n",
        "            )\n",
        "            if use_gpt: cls_token_location = inputs['input_ids'].index(tokenizer.cls_token_id) \n",
        "            \n",
        "            # find the subject and object tokens, choose the first ones\n",
        "            sub_st = sub_ed = obj_st = obj_ed = -1\n",
        "            for i in range(len(inputs['input_ids'])):\n",
        "                if sub_st == -1 and inputs['input_ids'][i:i+len(SUBJECT_ids)] == SUBJECT_ids:\n",
        "                    sub_st = i\n",
        "                    sub_ed = i + len(SUBJECT_ids)\n",
        "                if obj_st == -1 and inputs['input_ids'][i:i+len(OBJECT_ids)] == OBJECT_ids:\n",
        "                    obj_st = i\n",
        "                    obj_ed = i + len(OBJECT_ids)\n",
        "            \n",
        "            assert sub_st != -1 and obj_st != -1\n",
        "\n",
        "\n",
        "            num_tokens += sum(inputs['attention_mask'])\n",
        "\n",
        "\n",
        "            if sum(inputs['attention_mask']) > max_seq_length:\n",
        "                pass\n",
        "                # tokens = tokens[:max_seq_length]\n",
        "            else:\n",
        "                num_fit_examples += 1\n",
        "\n",
        "            x = OrderedDict()\n",
        "            x['input_ids'] = inputs['input_ids']\n",
        "            if use_bert: x['token_type_ids'] = inputs['token_type_ids']\n",
        "            x['attention_mask'] = inputs['attention_mask']\n",
        "            x['label'] = rel2id[example.label]\n",
        "            if use_gpt: x['cls_token_location'] = cls_token_location\n",
        "            x['so'] =[sub_st, sub_ed, obj_st, obj_ed]\n",
        "\n",
        "            instances.append(x)\n",
        "\n",
        "\n",
        "        with open(file=save_file, mode='wb') as fw:\n",
        "            pickle.dump(instances, fw)\n",
        "        print('Finish save preprocessed data to {}.'.format( save_file))\n",
        "\n",
        "    input_ids = [o['input_ids'] for o in instances]\n",
        "    attention_mask = [o['attention_mask'] for o in instances]\n",
        "    if use_bert: token_type_ids = [o['token_type_ids'] for o in instances]\n",
        "    if use_gpt: cls_idx = [o['cls_token_location'] for o in instances]\n",
        "    labels = [o['label'] for o in instances]\n",
        "    so = torch.tensor([o['so'] for o in instances])\n",
        "\n",
        "\n",
        "    input_ids = torch.tensor(input_ids)\n",
        "    attention_mask = torch.tensor(attention_mask)\n",
        "    if use_gpt: cls_idx = torch.tensor(cls_idx)\n",
        "    if use_bert: token_type_ids = torch.tensor(token_type_ids)\n",
        "    labels = torch.tensor(labels)\n",
        "\n",
        "    logger.info(\"Average #tokens: %.2f\" % (num_tokens * 1.0 / len(examples)))\n",
        "    logger.info(\"%d (%.2f %%) examples can fit max_seq_length = %d\" % (num_fit_examples,\n",
        "                num_fit_examples * 100.0 / len(examples), max_seq_length))\n",
        "\n",
        "    if use_gpt:\n",
        "        dataset = TensorDataset(input_ids, attention_mask, cls_idx, labels)\n",
        "    elif use_bert:\n",
        "        dataset = TensorDataset(input_ids, attention_mask, token_type_ids, labels, so)\n",
        "    else:\n",
        "        dataset = TensorDataset(input_ids, attention_mask, labels, so)\n",
        "    \n",
        "    return dataset\n",
        "\n",
        "def get_dataset(mode, cfg, tokenizer, processor):\n",
        "\n",
        "    if mode == \"train\":\n",
        "        examples = processor.get_train_examples(cfg.data_dir)\n",
        "    elif mode == \"dev\":\n",
        "        examples = processor.get_dev_examples(cfg.data_dir)\n",
        "    elif mode == \"test\":\n",
        "        examples = processor.get_test_examples(cfg.data_dir)\n",
        "    else:\n",
        "        raise Exception(\"mode must be in choice [trian, dev, test]\")\n",
        "    gpt_mode = \"wiki80\" in cfg.task_name\n",
        "    # normal relation extraction task\n",
        "    dataset = convert_examples_to_features(\n",
        "        examples, cfg.max_seq_length, tokenizer, cfg, processor.get_labels()\n",
        "    )\n",
        "    return dataset\n",
        "\n",
        "class REDataset(BaseDataModule):\n",
        "    def __init__(self, cfg) -> None:\n",
        "        super().__init__(cfg)\n",
        "        \n",
        "        self.cfg = cfg\n",
        "        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name_or_path)\n",
        "        self.processor = wiki80Processor(self.tokenizer, self.cfg.data_dir, self.cfg.use_prompt)\n",
        "        \n",
        "        use_gpt = \"gpt\" in cfg.model_name_or_path\n",
        "\n",
        "        rel2id = self.processor.get_labels()\n",
        "        self.num_labels = len(rel2id)\n",
        "\n",
        "        entity_list = [\"[object_start]\", \"[object_end]\", \"[subject_start]\", \"[subject_end]\"]\n",
        "        class_list = [f\"[class{i}]\" for i in range(1, self.num_labels+1)]\n",
        "\n",
        "        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': entity_list})\n",
        "        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': class_list})\n",
        "        if use_gpt:\n",
        "            self.tokenizer.add_special_tokens({'cls_token': \"[CLS]\"})\n",
        "            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
        "        so_list = [\"[sub]\", \"[obj]\"]\n",
        "        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': so_list})\n",
        "\n",
        "        prompt_tokens = [f\"[T{i}]\" for i in range(1,6)]\n",
        "        self.tokenizer.add_special_tokens({'additional_special_tokens': prompt_tokens})\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "    def setup(self, stage=None):\n",
        "        self.data_train = get_dataset(\"train\", self.cfg, self.tokenizer, self.processor)\n",
        "        self.data_val = get_dataset(\"dev\", self.cfg, self.tokenizer, self.processor)\n",
        "        self.data_test = get_dataset(\"test\", self.cfg, self.tokenizer, self.processor)\n",
        "\n",
        "\n",
        "    def prepare_data(self):\n",
        "        pass\n",
        "\n",
        "    def get_tokenizer(self):\n",
        "        return self.tokenizer"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "409840a8",
      "metadata": {
        "id": "409840a8"
      },
      "source": [
        "## Metric Functions"
      ]
    },
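    {
      "cell_type": "markdown",
      "id": "9d2f4a6c",
      "metadata": {
        "id": "9d2f4a6c"
      },
      "source": [
        "The metric code in the next cell scores predictions with a micro-averaged F1 that ignores the no-relation index. A minimal sketch of that computation on made-up predictions might look as follows; `micro_f1` and `NO_RELATION` are hypothetical names, and the index 36 is simply copied from the constant hard-coded in that cell.\n",
        "\n",
        "```python\n",
        "NO_RELATION = 36  # index treated as no relation (copied from the metric code)\n",
        "\n",
        "def micro_f1(preds, golds, ignore=NO_RELATION):\n",
        "    # preds / golds: one list of label indices per sample.\n",
        "    correct = sum(1 for p, g in zip(preds, golds) for i in g if i != ignore and i in p)\n",
        "    n_pred = sum(1 for p in preds for i in p if i != ignore)\n",
        "    n_gold = sum(1 for g in golds for i in g if i != ignore)\n",
        "    precision = 1 if n_pred == 0 else correct / n_pred\n",
        "    recall = 0 if n_gold == 0 else correct / n_gold\n",
        "    total = precision + recall\n",
        "    return 2 * precision * recall / total if total != 0 else 0\n",
        "\n",
        "# Made-up predictions and gold labels for three samples.\n",
        "preds = [[1], [36], [2, 5]]\n",
        "golds = [[1], [4], [2]]\n",
        "score = micro_f1(preds, golds)\n",
        "print(round(score, 4))\n",
        "```\n",
        "\n",
        "Here two of the three predicted relations are correct and two of the three gold relations are found, so precision, recall, and F1 all equal 2/3."
      ]
    },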
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f5dfdfea",
      "metadata": {
        "id": "f5dfdfea"
      },
      "outputs": [],
      "source": [
        "def dialog_f1_eval(logits, labels):\n",
        "    def getpred(result, T1=0.5, T2=0.4):\n",
        "        # 使用阈值得到preds, result = logits\n",
        "        # T2 表示如果都低于T2 那么就是 no relation, 否则选取一个最大的\n",
        "        ret = []\n",
        "        for i in range(len(result)):\n",
        "            r = []\n",
        "            maxl, maxj = -1, -1\n",
        "            for j in range(len(result[i])):\n",
        "                if result[i][j] > T1:\n",
        "                    r += [j]\n",
        "                if result[i][j] > maxl:\n",
        "                    maxl = result[i][j]\n",
        "                    maxj = j\n",
        "            if len(r) == 0:\n",
        "                if maxl <= T2:\n",
        "                    r = [36]\n",
        "                else:\n",
        "                    r += [maxj]\n",
        "            ret.append(r)\n",
        "        return ret\n",
        "\n",
        "    def geteval(devp, data):\n",
        "        correct_sys, all_sys = 0, 0\n",
        "        correct_gt = 0\n",
        "\n",
        "        for i in range(len(data)):\n",
        "            # 每一个样本 都是[1,4,...,20] 表示有1,4,20 是1， 如果没有就是[36]\n",
        "            for id in data[i]:\n",
        "                if id != 36:\n",
        "                    # 标签中 1 的个数\n",
        "                    correct_gt += 1\n",
        "                    if id in devp[i]:\n",
        "                        # 预测正确\n",
        "                        correct_sys += 1\n",
        "\n",
        "            for id in devp[i]:\n",
        "                if id != 36:\n",
        "                    all_sys += 1\n",
        "\n",
        "        precision = 1 if all_sys == 0 else correct_sys / all_sys\n",
        "        recall = 0 if correct_gt == 0 else correct_sys / correct_gt\n",
        "        f_1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0\n",
        "        return f_1\n",
        "\n",
        "    logits = np.asarray(logits)\n",
        "    logits = list(1 / (1 + np.exp(-logits)))\n",
        "\n",
        "    temp_labels = []\n",
        "    for l in labels:\n",
        "        t = []\n",
        "        for i in range(36):\n",
        "            if l[i] == 1:\n",
        "                t += [i]\n",
        "        if len(t) == 0:\n",
        "            t = [36]\n",
        "        temp_labels.append(t)\n",
        "    assert (len(labels) == len(logits))\n",
        "    labels = temp_labels\n",
        "\n",
        "    bestT2 = bestf_1 = 0\n",
        "    for T2 in range(51):\n",
        "        devp = getpred(logits, T2=T2 / 100.)\n",
        "        f_1 = geteval(devp, labels)\n",
        "        if f_1 > bestf_1:\n",
        "            bestf_1 = f_1\n",
        "            bestT2 = T2 / 100.\n",
        "\n",
        "    return dict(f1=bestf_1, T2=bestT2)\n",
        "\n",
        "\n",
        "\n",
        "def f1_eval(logits, labels):\n",
        "    def getpred(result, T1 = 0.5, T2 = 0.4) :\n",
        "        # 使用阈值得到preds, result = logits\n",
        "        # T2 表示如果都低于T2 那么就是 no relation, 否则选取一个最大的\n",
        "        ret = []\n",
        "        for i in range(len(result)):\n",
        "            r = []\n",
        "            maxl, maxj = -1, -1\n",
        "            for j in range(len(result[i])):\n",
        "                if result[i][j] > T1:\n",
        "                    r += [j]\n",
        "                if result[i][j] > maxl:\n",
        "                    maxl = result[i][j]\n",
        "                    maxj = j\n",
        "            if len(r) == 0:\n",
        "                if maxl <= T2:\n",
        "                    r = [36]\n",
        "                else:\n",
        "                    r += [maxj]\n",
        "            ret.append(r)\n",
        "        return ret\n",
        "\n",
        "    def geteval(devp, data):\n",
        "        correct_sys, all_sys = 0, 0\n",
        "        correct_gt = 0\n",
        "        \n",
        "        for i in range(len(data)):\n",
        "            # 每一个样本 都是[1,4,...,20] 表示有1,4,20 是1， 如果没有就是[36]\n",
        "            for id in data[i]:\n",
        "                if id != 36:\n",
        "                    # 标签中 1 的个数\n",
        "                    correct_gt += 1\n",
        "                    if id in devp[i]:\n",
        "                        # 预测正确\n",
        "                        correct_sys += 1\n",
        "\n",
        "            for id in devp[i]:\n",
        "                if id != 36:\n",
        "                    all_sys += 1\n",
        "\n",
        "        precision = 1 if all_sys == 0 else correct_sys/all_sys\n",
        "        recall = 0 if correct_gt == 0 else correct_sys/correct_gt\n",
        "        f_1 = 2*precision*recall/(precision+recall) if precision+recall != 0 else 0\n",
        "        return f_1\n",
        "\n",
        "    logits = np.asarray(logits)\n",
        "    logits = list(1 / (1 + np.exp(-logits)))\n",
        "\n",
        "    temp_labels = []\n",
        "    for l in labels:\n",
        "        t = []\n",
        "        for i in range(36):\n",
        "            if l[i] == 1:\n",
        "                t += [i]\n",
        "        if len(t) == 0:\n",
        "            t = [36]\n",
        "        temp_labels.append(t)\n",
        "    assert(len(labels) == len(logits))\n",
        "    labels = temp_labels\n",
        "    \n",
        "    bestT2 = bestf_1 = 0\n",
        "    for T2 in range(51):\n",
        "        devp = getpred(logits, T2=T2/100.)\n",
        "        f_1 = geteval(devp, labels)\n",
        "        if f_1 > bestf_1:\n",
        "            bestf_1 = f_1\n",
        "            bestT2 = T2/100.\n",
        "\n",
        "    return bestf_1, bestT2\n",
        "\n",
        "\n",
        "def f1_score(output, label, rel_num=42, na_num=13):\n",
        "    correct_by_relation = Counter()\n",
        "    guess_by_relation = Counter()\n",
        "    gold_by_relation = Counter()\n",
        "    output = np.argmax(output, axis=-1)\n",
        "\n",
        "    for i in range(len(output)):\n",
        "        guess = output[i]\n",
        "        gold = label[i]\n",
        "\n",
        "        if guess == na_num:\n",
        "            guess = 0\n",
        "        elif guess < na_num:\n",
        "            guess += 1\n",
        "\n",
        "        if gold == na_num:\n",
        "            gold = 0\n",
        "        elif gold < na_num:\n",
        "            gold += 1\n",
        "\n",
        "        if gold == 0 and guess == 0:\n",
        "            continue\n",
        "        if gold == 0 and guess != 0:\n",
        "            guess_by_relation[guess] += 1\n",
        "        if gold != 0 and guess == 0:\n",
        "            gold_by_relation[gold] += 1\n",
        "        if gold != 0 and guess != 0:\n",
        "            guess_by_relation[guess] += 1\n",
        "            gold_by_relation[gold] += 1\n",
        "            if gold == guess:\n",
        "                correct_by_relation[gold] += 1\n",
        "    \n",
        "    f1_by_relation = Counter()\n",
        "    recall_by_relation = Counter()\n",
        "    prec_by_relation = Counter()\n",
        "    for i in range(1, rel_num):\n",
        "        recall = 0\n",
        "        if gold_by_relation[i] > 0:\n",
        "            recall = correct_by_relation[i] / gold_by_relation[i]\n",
        "        precision = 0\n",
        "        if guess_by_relation[i] > 0:\n",
        "            precision = correct_by_relation[i] / guess_by_relation[i]\n",
        "        if recall + precision > 0 :\n",
        "            f1_by_relation[i] = 2 * recall * precision / (recall + precision)\n",
        "        recall_by_relation[i] = recall\n",
        "        prec_by_relation[i] = precision\n",
        "\n",
        "    micro_f1 = 0\n",
        "    if sum(guess_by_relation.values()) != 0 and sum(correct_by_relation.values()) != 0:\n",
        "        recall = sum(correct_by_relation.values()) / sum(gold_by_relation.values())\n",
        "        prec = sum(correct_by_relation.values()) / sum(guess_by_relation.values())    \n",
        "        micro_f1 = 2 * recall * prec / (recall+prec)\n",
        "\n",
        "    return dict(f1=micro_f1)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "7bec83d2",
      "metadata": {
        "id": "7bec83d2"
      },
      "source": [
        "## Model Construction"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "885d21de",
      "metadata": {
        "id": "885d21de"
      },
      "source": [
        "### Base Model Class"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0ce6d9b3",
      "metadata": {
        "id": "0ce6d9b3"
      },
      "outputs": [],
      "source": [
        "OPTIMIZER = \"AdamW\"\n",
        "LR = 5e-5\n",
        "LOSS = \"cross_entropy\"\n",
        "ONE_CYCLE_TOTAL_STEPS = 100\n",
        "\n",
        "class BaseLitModel(nn.Module):\n",
        "    \"\"\"\n",
        "    Generic PyTorch-Lightning class that must be initialized with a PyTorch module.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model, cfg, device: str = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):\n",
        "        super().__init__()\n",
        "        self.model = model\n",
        "        self.cur_model = model.module if hasattr(model, 'module') else model\n",
        "        self.device = device\n",
        "        self.cfg = cfg if cfg is not None else {}\n",
        "\n",
        "        optimizer = self.cfg.optimizer\n",
        "        self.optimizer_class = getattr(torch.optim, optimizer)\n",
        "        self.lr = self.cfg.lr\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)\n",
        "        if self.one_cycle_max_lr is None:\n",
        "            return optimizer\n",
        "        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps)\n",
        "        return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler, \"monitor\": \"val_loss\"}\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.model(x)\n",
        "\n",
        "    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument\n",
        "        x, y = batch\n",
        "        x.to(self.device)\n",
        "        logits = x\n",
        "        loss = (logits - y) ** 2\n",
        "        print(\"train_loss: \", loss)\n",
        "        #self.train_acc(logits, y)\n",
        "        #self.log(\"train_acc\", self.train_acc, on_step=False, on_epoch=True)\n",
        "        return loss\n",
        "\n",
        "    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument\n",
        "        x, y = batch\n",
        "        x.to(self.device)\n",
        "        logits = x\n",
        "        loss = (logits - y) ** 2\n",
        "        print(\"val_loss: \", loss)\n",
        "\n",
        "    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument\n",
        "        x, y = batch\n",
        "        x.to(self.device)\n",
        "        logits = x\n",
        "        loss = (logits - y) ** 2\n",
        "        print(\"test_loss: \", loss)\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        no_decay_param = [\"bias\", \"LayerNorm.weight\"]\n",
        "\n",
        "        optimizer_group_parameters = [\n",
        "            {\"params\": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay_param)], \"weight_decay\": self.cfg.weight_decay},\n",
        "            {\"params\": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay_param)], \"weight_decay\": 0}\n",
        "        ]\n",
        "\n",
        "        \n",
        "        optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)\n",
        "        #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_training_steps * 0.1, num_training_steps=self.num_training_steps)\n",
        "        return optimizer"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "2f55cddc",
      "metadata": {
        "id": "2f55cddc"
      },
      "source": [
        "### Model Subclass"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c01fa9d1",
      "metadata": {
        "id": "c01fa9d1"
      },
      "outputs": [],
      "source": [
        "def multilabel_categorical_crossentropy(y_pred, y_true):\n",
        "    y_pred = (1 - 2 * y_true) * y_pred\n",
        "    y_pred_neg = y_pred - y_true * 1e12\n",
        "    y_pred_pos = y_pred - (1 - y_true) * 1e12\n",
        "    zeros = torch.zeros_like(y_pred[..., :1])\n",
        "    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)\n",
        "    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)\n",
        "    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)\n",
        "    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)\n",
        "    return (neg_loss + pos_loss).mean()\n",
        "\n",
        "class BertLitModel(BaseLitModel):\n",
        "    \"\"\"\n",
        "    use AutoModelForMaskedLM, and select the output by another layer in the lit model\n",
        "    \"\"\"\n",
        "    def __init__(self, model, cfg, tokenizer):\n",
        "        super().__init__(model, cfg)\n",
        "        self.tokenizer = tokenizer\n",
        "        with open(f\"{cfg.data_dir}/rel2id.json\",\"r\") as file:\n",
        "            rel2id = json.load(file)\n",
        "        \n",
        "        Na_num = 0\n",
        "        for k, v in rel2id.items():\n",
        "            if k == \"NA\" or k == \"no_relation\" or k == \"Other\":\n",
        "                Na_num = v\n",
        "                break\n",
        "        num_relation = len(rel2id)\n",
        "        # init loss function\n",
        "        self.loss_fn = multilabel_categorical_crossentropy if \"dialogue\" in cfg.data_dir else nn.CrossEntropyLoss()\n",
        "        # ignore the no_relation class to compute the f1 score\n",
        "        self.eval_fn = f1_eval if \"dialogue\" in cfg.data_dir else partial(f1_score, rel_num=num_relation, na_num=Na_num)\n",
        "        self.best_f1 = 0\n",
        "        self.t_lambda = cfg.t_lambda\n",
        "        \n",
        "        self.label_st_id = tokenizer(\"[class1]\", add_special_tokens=False)['input_ids'][0]\n",
        "    \n",
        "        self._init_label_word()\n",
        "\n",
        "    def _init_label_word(self):\n",
        "        cfg = self.cfg\n",
        "        # ./dataset/dataset_name\n",
        "        dataset_name = cfg.data_dir.split(\"/\")[1]\n",
        "        model_name_or_path = cfg.model_name_or_path.split(\"/\")[-1]\n",
        "        label_path = f\"data/{model_name_or_path}.pt\"\n",
        "        # [num_labels, num_tokens], ignore the unanswerable\n",
        "        if \"dialogue\" in cfg.data_dir:\n",
        "            label_word_idx = torch.load(label_path)[:-1]\n",
        "        else:\n",
        "            label_word_idx = torch.load(label_path)\n",
        "        \n",
        "        num_labels = len(label_word_idx)\n",
        "        \n",
        "        self.cur_model.resize_token_embeddings(len(self.tokenizer))\n",
        "        with torch.no_grad():\n",
        "            word_embeddings = self.cur_model.get_input_embeddings()\n",
        "            continous_label_word = [a[0] for a in self.tokenizer([f\"[class{i}]\" for i in range(1, num_labels+1)], add_special_tokens=False)['input_ids']]\n",
        "            for i, idx in enumerate(label_word_idx):\n",
        "                word_embeddings.weight[continous_label_word[i]] = torch.mean(word_embeddings.weight[idx], dim=0)\n",
        "                # word_embeddings.weight[continous_label_word[i]] = self.relation_embedding[i]\n",
        "            so_word = [a[0] for a in self.tokenizer([\"[obj]\",\"[sub]\"], add_special_tokens=False)['input_ids']]\n",
        "            meaning_word = [a[0] for a in self.tokenizer([\"person\",\"organization\", \"location\", \"date\", \"country\"], add_special_tokens=False)['input_ids']]\n",
        "            \n",
        "            for i, idx in enumerate(so_word):\n",
        "                word_embeddings.weight[so_word[i]] = torch.mean(word_embeddings.weight[meaning_word], dim=0)\n",
        "            assert torch.equal(self.cur_model.get_input_embeddings().weight, word_embeddings.weight)\n",
        "            assert torch.equal(self.cur_model.get_input_embeddings().weight, self.cur_model.get_output_embeddings().weight)\n",
        "        \n",
        "        self.word2label = continous_label_word # a continous list\n",
        "            \n",
        "                \n",
        "    def forward(self, x):\n",
        "        return self.model(x)\n",
        "\n",
        "    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument\n",
        "        input_ids, attention_mask, token_type_ids , labels, so = batch\n",
        "        input_ids = input_ids.to(self.device)\n",
        "        attention_mask = attention_mask.to(self.device)\n",
        "        token_type_ids = token_type_ids.to(self.device)\n",
        "        labels = labels.to(self.device)\n",
        "        so = so.to(self.device)\n",
        "        result = self.model(input_ids, attention_mask, token_type_ids, return_dict=True, output_hidden_states=True)\n",
        "        logits = result.logits\n",
        "        output_embedding = result.hidden_states[-1]\n",
        "        logits = self.pvp(logits, input_ids)\n",
        "        loss = self.loss_fn(logits, labels) + self.t_lambda * self.ke_loss(output_embedding, labels, so)\n",
        "        #print(\"Train/loss: \", loss)\n",
        "        return loss\n",
        "\n",
        "    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument\n",
        "        input_ids, attention_mask, token_type_ids , labels, _ = batch\n",
        "        input_ids = input_ids.to(self.device)\n",
        "        attention_mask = attention_mask.to(self.device)\n",
        "        token_type_ids = token_type_ids.to(self.device)\n",
        "        labels = labels.to(self.device)\n",
        "        logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits\n",
        "        logits = self.pvp(logits, input_ids)\n",
        "        loss = self.loss_fn(logits, labels)\n",
        "        #print(\"Eval/loss: \", loss)\n",
        "        return {\"loss\": loss, \"eval_logits\": logits.detach().cpu().numpy(), \"eval_labels\": labels.detach().cpu().numpy()}\n",
        "    \n",
        "    def validation_epoch_end(self, outputs):\n",
        "        logits = np.concatenate([o[\"eval_logits\"] for o in outputs])\n",
        "        labels = np.concatenate([o[\"eval_labels\"] for o in outputs])\n",
        "\n",
        "        f1 = self.eval_fn(logits, labels)['f1']\n",
        "        #print(\"Eval/f1: \", f1)\n",
        "        best_f1 = -1\n",
        "        if f1 > self.best_f1:\n",
        "            self.best_f1 = f1\n",
        "            best_f1 = self.best_f1\n",
        "        #print(\"Eval/best_f1: \", self.best_f1)\n",
        "        return f1, best_f1, self.best_f1\n",
        "\n",
        "    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument\n",
        "        input_ids, attention_mask, token_type_ids , labels, _ = batch\n",
        "        input_ids = input_ids.to(self.device)\n",
        "        attention_mask = attention_mask.to(self.device)\n",
        "        token_type_ids = token_type_ids.to(self.device)\n",
        "        labels = labels.to(self.device)\n",
        "        logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits\n",
        "        logits = self.pvp(logits, input_ids)\n",
        "        return {\"test_logits\": logits.detach().cpu().numpy(), \"test_labels\": labels.detach().cpu().numpy()}\n",
        "\n",
        "    def test_epoch_end(self, outputs):\n",
        "        logits = np.concatenate([o[\"test_logits\"] for o in outputs])\n",
        "        labels = np.concatenate([o[\"test_labels\"] for o in outputs])\n",
        "\n",
        "        f1 = self.eval_fn(logits, labels)['f1']\n",
        "        #print(\"Test/f1: \", f1)\n",
        "        return f1\n",
        "        \n",
        "    def pvp(self, logits, input_ids):\n",
        "        # convert the [batch_size, seq_len, vocab_size] => [batch_size, num_labels]\n",
        "        mask_id = self.tokenizer(self.tokenizer.mask_token, add_special_tokens = False)['input_ids'][0]\n",
        "        _, mask_idx = (input_ids == mask_id).nonzero(as_tuple=True)\n",
        "        bs = input_ids.shape[0]\n",
        "        mask_output = logits[torch.arange(bs), mask_idx]\n",
        "        assert mask_idx.shape[0] == bs, \"only one mask in sequence!\"\n",
        "        final_output = mask_output[:,self.word2label]\n",
        "        \n",
        "        return final_output\n",
        "        \n",
        "    def ke_loss(self, logits, labels, so):\n",
        "        subject_embedding = []\n",
        "        object_embedding = []\n",
        "        bsz = logits.shape[0]\n",
        "        for i in range(bsz):\n",
        "            subject_embedding.append(torch.mean(logits[i, so[i][0]:so[i][1]], dim=0))\n",
        "            object_embedding.append(torch.mean(logits[i, so[i][2]:so[i][3]], dim=0))\n",
        "            \n",
        "        subject_embedding = torch.stack(subject_embedding)\n",
        "        object_embedding = torch.stack(object_embedding)\n",
        "        # trick , the relation ids is concated, \n",
        "        relation_embedding = self.cur_model.get_output_embeddings().weight[labels+self.label_st_id]\n",
        "        \n",
        "        loss = torch.norm(subject_embedding + relation_embedding - object_embedding, p=2)\n",
        "        \n",
        "        return loss\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        no_decay_param = [\"bias\", \"LayerNorm.weight\"]\n",
        "\n",
        "        if not self.cfg.two_steps: \n",
        "            parameters = self.cur_model.named_parameters()\n",
        "        else:\n",
        "            # cur_model.bert.embeddings.weight\n",
        "            parameters = [next(self.cur_model.named_parameters())]\n",
        "        # only optimize the embedding parameters\n",
        "        optimizer_group_parameters = [\n",
        "            {\"params\": [p for n, p in parameters if not any(nd in n for nd in no_decay_param)], \"weight_decay\": self.cfg.weight_decay},\n",
        "            {\"params\": [p for n, p in parameters if any(nd in n for nd in no_decay_param)], \"weight_decay\": 0}\n",
        "        ]\n",
        "\n",
        "        \n",
        "        optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)\n",
        "        return optimizer"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4e4baf1f",
      "metadata": {
        "id": "4e4baf1f"
      },
      "source": [
        "## Preprocess the inputs"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "53752c58",
      "metadata": {
        "id": "53752c58"
      },
      "source": [
        "### Few-shot sampling"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "95353afa",
      "metadata": {
        "id": "95353afa"
      },
      "outputs": [],
      "source": [
        "Seed = [1, 2, 3, 4, 5]\n",
        "mode = 'k-shot'\n",
        "data_file = 'train.txt'\n",
        "\n",
        "def get_labels(path, name,  negative_label=\"no_relation\"):\n",
        "    \"\"\"See base class.\"\"\"\n",
        "\n",
        "    count = Counter()\n",
        "    with open(path + \"/\" + name, \"r\") as f:\n",
        "        features = []\n",
        "        for line in f.readlines():\n",
        "            line = line.rstrip()\n",
        "            if len(line) > 0:\n",
        "                # count[line['relation']] += 1\n",
        "                features.append(eval(line))\n",
        "\n",
        "    # logger.info(\"label distribution as list: %d labels\" % len(count))\n",
        "    # # Make sure the negative label is alwyas 0\n",
        "    # labels = []\n",
        "    # for label, count in count.most_common():\n",
        "    #     logger.info(\"%s: %d 个 %.2f%%\" % (label, count,  count * 100.0 / len(dataset)))\n",
        "    #     if label not in labels:\n",
        "    #         labels.append(label)\n",
        "    return features\n",
        "\n",
        "path = 'data'\n",
        "\n",
        "output_dir = os.path.join(path, mode)\n",
        "dataset = get_labels(path, data_file)\n",
        "\n",
        "for seed in Seed:\n",
        "\n",
        "    # Other datasets\n",
        "    np.random.seed(seed)\n",
        "    np.random.shuffle(dataset)\n",
        "\n",
        "    # Set up dir\n",
        "    k = 8\n",
        "    setting_dir = os.path.join(output_dir, f\"{k}-{seed}\")\n",
        "    os.makedirs(setting_dir, exist_ok=True)\n",
        "\n",
        "    label_list = {}\n",
        "    for line in dataset:\n",
        "        label = line['relation']\n",
        "        if label not in label_list:\n",
        "            label_list[label] = [line]\n",
        "        else:\n",
        "            label_list[label].append(line)\n",
        "\n",
        "    with open(os.path.join(setting_dir, \"train.txt\"), \"w\") as f:\n",
        "        file_list = []\n",
        "        for label in label_list:\n",
        "            for line in label_list[label][:k]:  # train中每一类取前k个数据\n",
        "                f.writelines(json.dumps(line))\n",
        "                f.write('\\n')\n",
        "\n",
        "        f.close()\n",
        "\n",
        "shutil.copyfile('data/rel2id.json','data/k-shot/8-1/rel2id.json')\n",
        "shutil.copyfile('data/val.txt','data/k-shot/8-1/val.txt')\n",
        "shutil.copyfile('data/test.txt','data/k-shot/8-1/test.txt')"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "835ff58f",
      "metadata": {
        "id": "835ff58f"
      },
      "source": [
        "### Obtain labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "25654268",
      "metadata": {
        "id": "25654268"
      },
      "outputs": [],
      "source": [
        "def split_label_words(tokenizer, label_list):\n",
        "    label_word_list = []\n",
        "    for label in label_list:\n",
        "        if label == 'no_relation':\n",
        "            label_word_id = tokenizer.encode('None', add_special_tokens=False)\n",
        "            label_word_list.append(torch.tensor(label_word_id))\n",
        "        else:\n",
        "            tmps = label\n",
        "            label = label.lower()\n",
        "            label = label.split(\"(\")[0]\n",
        "            label = label.replace(\":\",\" \").replace(\"_\",\" \").replace(\"per\",\"person\").replace(\"org\",\"organization\")\n",
        "            label_word_id = tokenizer(label, add_special_tokens=False)['input_ids']\n",
        "            print(label, label_word_id)\n",
        "            label_word_list.append(torch.tensor(label_word_id))\n",
        "    padded_label_word_list = pad_sequence([x for x in label_word_list], batch_first=True, padding_value=0)\n",
        "    return padded_label_word_list\n",
        "\n",
        "\n",
        "model_name_or_path = cfg.model_name_or_path\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
        "with open(\"data/rel2id.json\", \"r\") as file:\n",
        "    t = json.load(file)\n",
        "    label_list = list(t)\n",
        "\n",
        "t = split_label_words(tokenizer, label_list)\n",
        "\n",
        "with open(f\"data/{model_name_or_path}.pt\", \"wb\") as file:\n",
        "    torch.save(t, file)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "7f95f210",
      "metadata": {
        "id": "7f95f210"
      },
      "source": [
        "## Auxiliary functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0f0ea75b",
      "metadata": {
        "id": "0f0ea75b"
      },
      "outputs": [],
      "source": [
        "def set_seed(cfg):\n",
        "    torch.cuda.manual_seed_all(cfg.seed)\n",
        "    np.random.seed(cfg.seed)\n",
        "    torch.manual_seed(cfg.seed)\n",
        "    torch.cuda.manual_seed_all(cfg.seed)\n",
        "\n",
        "def logging(log_dir, s, print_=True, log_=True):\n",
        "    if print_:\n",
        "        print(s)\n",
        "    if log_dir != '' and log_:\n",
        "        with open(log_dir, 'a+') as f_log:\n",
        "            f_log.write(s + '\\n')"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "19fbd3ac",
      "metadata": {
        "id": "19fbd3ac"
      },
      "source": [
        "## Train the model\n",
        "### Model training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3911f6f8",
      "metadata": {
        "id": "3911f6f8"
      },
      "outputs": [],
      "source": [
        "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "data = REDataset(cfg)\n",
        "data_config = data.get_data_config()\n",
        "\n",
        "config = AutoConfig.from_pretrained(cfg.model_name_or_path)\n",
        "config.num_labels = data_config[\"num_labels\"]\n",
        "\n",
        "model = AutoModelForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config)\n",
        "\n",
        "\n",
        "    \n",
        "# if torch.cuda.device_count() > 1:\n",
        "#     print(\"Let's use\", torch.cuda.device_count(), \"GPUs!\")\n",
        "#     model = torch.nn.DataParallel(model, device_ids = list(range(torch.cuda.device_count())))\n",
        "\n",
        "model.to(device)\n",
        "\n",
        "lit_model = BertLitModel(model, cfg, data.tokenizer)\n",
        "data.setup()\n",
        "\n",
        "if cfg.train_from_saved_model != '':\n",
        "    model.load_state_dict(torch.load(cfg.train_from_saved_model)[\"checkpoint\"])\n",
        "    print(\"load saved model from {}.\".format(cfg.train_from_saved_model))\n",
        "    lit_model.best_f1 = torch.load(cfg.train_from_saved_model)[\"best_f1\"]\n",
        "#data.tokenizer.save_pretrained('test')\n",
        "\n",
        "\n",
        "optimizer = lit_model.configure_optimizers()\n",
        "if cfg.train_from_saved_model != '':\n",
        "    optimizer.load_state_dict(torch.load(cfg.train_from_saved_model)[\"optimizer\"])\n",
        "    print(\"load saved optimizer from {}.\".format(cfg.train_from_saved_model))\n",
        "\n",
        "num_training_steps = len(data.train_dataloader()) // cfg.gradient_accumulation_steps * cfg.num_train_epochs\n",
        "scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_training_steps * 0.1, num_training_steps=num_training_steps)\n",
        "log_step = 20\n",
        "\n",
        "\n",
        "logging(cfg.log_dir,'-' * 89, print_=True)\n",
        "logging(cfg.log_dir, time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()) + ' INFO : START TO TRAIN ', print_=True)\n",
        "logging(cfg.log_dir,'-' * 89, print_=True)\n",
        "\n",
        "for epoch in range(cfg.num_train_epochs):\n",
        "    model.train()\n",
        "    num_batch = len(data.train_dataloader())\n",
        "    total_loss = 0\n",
        "    log_loss = 0\n",
        "    for index, train_batch in enumerate(data.train_dataloader()):\n",
        "        loss = lit_model.training_step(train_batch, index)\n",
        "        total_loss += loss.item()\n",
        "        log_loss += loss.item()\n",
        "        loss.backward()\n",
        "\n",
        "        optimizer.step()\n",
        "        scheduler.step()\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        if log_step > 0 and (index+1) % log_step == 0:\n",
        "            cur_loss = log_loss / log_step\n",
        "            logging(cfg.log_dir, \n",
        "                '| epoch {:2d} | step {:4d} | lr {} | train loss {:5.3f}'.format(\n",
        "                    epoch, (index+1), scheduler.get_last_lr(), cur_loss)\n",
        "                , print_=True)\n",
        "            log_loss = 0\n",
        "    avrg_loss = total_loss / num_batch\n",
        "    logging(cfg.log_dir,\n",
        "        '| epoch {:2d} | train loss {:5.3f}'.format(\n",
        "            epoch, avrg_loss))\n",
        "        \n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        val_loss = []\n",
        "        for val_index, val_batch in enumerate(tqdm(data.val_dataloader())):\n",
        "            loss = lit_model.validation_step(val_batch, val_index)\n",
        "            val_loss.append(loss)\n",
        "        f1, best, best_f1 = lit_model.validation_epoch_end(val_loss)\n",
        "        logging(cfg.log_dir,'-' * 89)\n",
        "        logging(cfg.log_dir,\n",
        "            '| epoch {:2d} | dev_result: {}'.format(epoch, f1))\n",
        "        logging(cfg.log_dir,'-' * 89)\n",
        "        logging(cfg.log_dir,\n",
        "            '| best_f1: {}'.format(best_f1))\n",
        "        logging(cfg.log_dir,'-' * 89)\n",
        "        if cfg.save_path != \"\" and best != -1:\n",
        "            save_path = cfg.save_path\n",
        "            torch.save({\n",
        "                'epoch': epoch,\n",
        "                'checkpoint': model.state_dict(),\n",
        "                'best_f1': best_f1,\n",
        "                'optimizer': optimizer.state_dict()\n",
        "            }, save_path\n",
        "            , _use_new_zipfile_serialization=False)\n",
        "            logging(cfg.log_dir,\n",
        "                '| successfully save model at: {}'.format(save_path))\n",
        "            logging(cfg.log_dir,'-' * 89)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b1118569",
      "metadata": {
        "id": "b1118569"
      },
      "source": [
        "### Model prediction"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3f9292a2",
      "metadata": {
        "id": "3f9292a2"
      },
      "outputs": [],
      "source": [
        "def test(cfg, model, lit_model, data):\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        test_loss = []\n",
        "        for test_index, test_batch in enumerate(tqdm(data.test_dataloader())):\n",
        "            loss = lit_model.test_step(test_batch, test_index)\n",
        "            test_loss.append(loss)\n",
        "        f1 = lit_model.test_epoch_end(test_loss)\n",
        "        logging(cfg.log_dir,\n",
        "            '| test_result: {}'.format(f1))\n",
        "        logging(cfg.log_dir,'-' * 89)\n",
        "\n",
        "test(cfg, model, lit_model, data)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "fewshot_re_tutorial.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
